#!/usr/bin/env python
# coding: utf-8
# based on https://github.com/andersbll/nnet
# modified by heibanke

import numpy as np
import scipy as sp
from .layers import ParamMixin, MSECostLayer
from .helpers import one_hot, unhot


class neuralnetwork:
    """Sequential stack of layers trained with mini-batch gradient descent.

    The network feeds its input through ``layers`` in order, scores the
    result with ``cost`` and updates every ``ParamMixin`` layer in place.
    """

    def __init__(self, layers, cost=None, rng=None):
        """Store the layer stack, the cost layer and the random generator.

        Args:
            layers: Sequence of layer objects, applied in order.
            cost: Cost layer providing ``loss`` and ``input_grad``. When
                omitted, a new ``MSECostLayer`` is created for this network
                (a default built in the signature would be evaluated once
                and shared by every network).
            rng: Random generator handed to each layer's ``_setup``. When
                omitted, a fresh ``np.random.RandomState`` is used.
        """
        self.layers = layers
        if rng is None:
            rng = np.random.RandomState()
        self.rng = rng
        if cost is None:
            cost = MSECostLayer()
        self.cost = cost

    @staticmethod
    def _batch_bounds(n_samples, batch_size):
        """Return the ``(begin, end)`` index pairs of consecutive batches.

        The last pair is shorter than ``batch_size`` when ``n_samples`` is
        not a multiple of it, so every sample belongs to exactly one batch.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError('batch_size must be a positive integer, got %r'
                             % (batch_size,))
        return [(begin, min(begin + batch_size, n_samples))
                for begin in range(0, n_samples, batch_size)]

    def _setup(self, X, Y):
        """Initialise the layers in order and validate the output shape.

        Each layer is set up with the shape produced by the previous one,
        starting from ``X.shape``.

        Raises:
            ValueError: If the shape after the last layer differs from
                ``Y.shape``.
        """
        next_shape = X.shape
        for layer in self.layers:
            layer._setup(next_shape, self.rng)
            next_shape = layer.output_shape(next_shape)
        if next_shape != Y.shape:
            raise ValueError('Output shape %s does not match Y %s'
                             % (next_shape, Y.shape))

    def _fprop(self, X):
        """Forward-propagate ``X`` through every layer; return the output."""
        X_next = X
        for layer in self.layers:
            X_next = layer.fprop(X_next)
        return X_next

    def train(self, X, Y, learning_rate=0.1, max_iter=10, batch_size=64):
        """Train the network on the given data.

        Runs ``max_iter`` passes over the data in order (no shuffling),
        updating parameters after every mini-batch and printing the loss
        and training error after every pass.

        Args:
            X: Input samples, indexed by sample along the first axis.
            Y: Targets, converted with ``one_hot`` before training
                (presumably integer class labels -- confirm against
                ``helpers.one_hot``).
            learning_rate: Step size applied to each parameter gradient.
            max_iter: Number of passes over the data.
            batch_size: Samples per mini-batch; a shorter final batch is
                used when the sample count is not a multiple of it.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1, or if the
                network output shape does not match the one-hot targets.
        """
        n_samples = Y.shape[0]
        # Computed once: the batch boundaries are the same for every pass.
        # Validating batch_size here also fails before any layer is set up.
        batches = self._batch_bounds(n_samples, batch_size)
        Y_one_hot = one_hot(Y)
        self._setup(X, Y_one_hot)

        epoch = 0
        # Stochastic gradient descent with mini-batches
        while epoch < max_iter:
            epoch += 1
            for batch_begin, batch_end in batches:
                X_batch = X[batch_begin:batch_end]
                Y_batch = Y_one_hot[batch_begin:batch_end]

                # Forward propagation
                Y_pred = self._fprop(X_batch)

                # Back propagation of partial derivatives
                next_grad = self.cost.input_grad(Y_batch, Y_pred)
                for layer in reversed(self.layers):
                    next_grad = layer.bprop(next_grad)

                # Update parameters
                for layer in self.layers:
                    if isinstance(layer, ParamMixin):
                        for param, inc in zip(layer.params(),
                                              layer.param_grads()):
                            # Augmented assignment on purpose: the update
                            # must modify the layer's own parameter object.
                            param -= learning_rate*inc

            # Output training status
            loss = self._loss(X, Y_one_hot)
            error = self.error(X, Y)
            print('iter %i, loss %.4f, train error %.4f'
                  % (epoch, loss, error))

    def _loss(self, X, Y_one_hot):
        """Return the cost layer's loss of the network output for ``X``."""
        return self.cost.loss(Y_one_hot, self._fprop(X))

    def predict(self, X):
        """ Calculate an output Y for the given input X. """
        return unhot(self._fprop(X))

    def error(self, X, Y):
        """ Return the fraction of predictions for X that differ from Y. """
        Y_pred = self.predict(X)
        error = Y_pred != Y
        return np.mean(error)



# Disabled draft of a SciPy-based training routine. It is wrapped in a bare
# string literal, so none of it is executed or attached to the class.
# NOTE(review): as written it would not run if re-enabled -- it calls
# `opt.minimize` although only `import scipy.optimize` appears (no `opt`
# alias), and it relies on `self.forward_backward` / `self.wb_init`, which
# the class in this file does not define.
"""
    def train_scipy(self,X,y):
        #scipy实现的训练
        
        m,n = X.shape
        import scipy.optimize
        options = {'maxiter': 1000, 'disp': True}
        J = lambda wb: self.forward_backward(wb, X, y)
        theta = self.wb_init(X,y) 

        result = opt.minimize(J, theta, method='L-BFGS-B', jac=True, options=options)
        self.theta = result.x
"""
